Amazon SageMakerの学習ジョブ完了通知bot(Slack)を作ってみた
こんにちは、大阪DI部の大澤です。
Amazon SageMakerの学習ジョブが完了した際にジョブの情報をSlackのチャンネルに投稿するbotを作ってみました。今回はその内容をご紹介します。
動機
- 自分が回している学習ジョブの完了通知が欲しい
- 他の人がどういう学習ジョブを回しているかを知りたい
仕組み
- 学習ジョブを開始します。
- 学習ジョブが完了し、モデルアーティファクトが
s3://<bucket-name>/<prefix>/<job-name>/output/model.tar.gz
に保存されます。 output/model.tar.gz
という接尾辞の条件を満たすキーのオブジェクトが保存されたことをトリガーとして、Lambda関数が実行されます。- Lambda関数でモデルアーティファクトのオブジェクトキーに含まれる学習ジョブ名
から学習ジョブ情報等を取得&整形し、SlackのIncoming Webhooksの専用URLへPOSTします。 - Slackに学習ジョブ情報が投稿されます。
※ 学習ジョブの完了をLambdaに通知する良い方法が思いつかなかったので、モデルアーティファクトの保存パスのルールを利用しました。なので、学習ジョブが失敗した時などモデルアーティファクトが保存されない場合には対応していません。
※ 全てのS3バケットに対応する訳ではありません。Lambda関数を実行するイベントトリガーを設定したS3バケットに対応します。複数のバケットに対応する場合はそれぞれイベントトリガーの設定が必要になります。
やってみる
Slack
SlackのワークスペースにIncoming Webhooksを追加し、メッセージを送信するために使用するエンドポイントURLを取得します。
Lambda関数の作成
次にマネジメントコンソールを開き、Lambda関数を作成します。 スクラッチを選択し、名前やランタイムや実行ロールを設定します。
- ランタイムはPython3.7にします。
- 実行ロールに今回は3つのポリシー
AWSPriceListServiceFullAccess
、AWSLambdaBasicExecutionRole
、AmazonSageMakerReadOnly
を付けています。- 実際にはもう少し権限は絞れますが、AWS管理ポリシーで楽をしました。
処理の紹介
Lambda関数で行う処理の流れは次のような感じです。
- モデルアーティファクトが保存されたオブジェクトキーから学習ジョブ名を取得します。
<prefix>/<job名>/output/model.tar.gz
という想定
- 学習ジョブ情報等を取得し、投稿用テキストを作成します。
- SageMakerのdescribe_training_jobで学習ジョブの詳細情報を取得し、Pricingのget_productsでインスタンスの価格情報を取得します。
- 取得した情報をもとにステータスや概算費用、学習時間、インスタンスタイプ、入力データなどのデータを整形し、投稿用テキストを作成します。
- Slack Incoming Webhooksの専用のエンドポイントURLに投稿データをPOSTします。
- 投稿データには学習ジョブ情報を記載したメッセージテキスト、通知先チャンネル、投稿時の表示名が含まれます。
以下がそのスクリプトになります。
import os import json import urllib import boto3 from datetime import timezone, timedelta # パラメータを読み込む BOT_USERNAME = os.environ['BOT_USERNAME'] SLACK_URL = os.environ['SLACK_URL'] SLACK_CHANNEL = os.environ['SLACK_CHANNEL'] # インスタンスタイプとリージョンに応じた価格取得時に指定する地域名(雑だけどもいい方法がないので、とりあえず辞書で対応) REGION_MAP = { 'us-gov-west-1' : 'AWS GovCloud (US)', 'ap-south-1' : 'Asia Pacific (Mumbai)', 'ap-northeast-2' : 'Asia Pacific (Seoul)', 'ap-southeast-1' : 'Asia Pacific (Singapore)', 'ap-southeast-2' : 'Asia Pacific (Sydney)', 'ap-northeast-1' : 'Asia Pacific (Tokyo)', 'ca-central-1' : 'Canada (Central)', 'eu-central-1' : 'EU (Frankfurt)', 'eu-west-1' : 'EU (Ireland)', 'eu-west-2' : 'EU (London)', 'us-east-1' : 'US East (N. Virginia)', 'us-east-2' : 'US East (Ohio)', 'us-west-1' : 'US West (N. California)', 'us-west-2' : 'US West (Oregon)', } # JSTを定義する JST = timezone(timedelta(hours=9), 'JST') def lambda_handler(event, context): """Lambdaから呼ばれるコールバック関数""" object_key = event['Records'][0]['s3']['object']['key'] # "<prefix>/<job名>/output/model.tar.gz"となる想定で、学習ジョブ名を取得する job_name = object_key.split('/')[-3] # バケットのあるリージョン region = event['Records'][0]['awsRegion'] # ジョブ情報を作成 job_info = create_training_job_info(job_name, region) # Slackにメッセージを送信する send_slack_message(job_info) return { 'statusCode': 200, 'body': json.dumps('Hello from Lambda!') } def send_slack_message(text): """Slackにメッセージを送信する""" data = { 'username':BOT_USERNAME, # 表示名 'text':text, # 内容 'channel': SLACK_CHANNEL # 送信先チャンネル } method = "POST" headers = {"Content-Type" : "application/json"} req = urllib.request.Request(SLACK_URL, method=method, data=json.dumps(data).encode(), headers=headers) with urllib.request.urlopen(req) as res: body = res.read() return body def create_training_job_info(job_name, region): """ジョブ名とリージョンから学習ジョブ情報を作成する""" sm = boto3.client('sagemaker', region_name=region) # ジョブデータを取得する job_data = sm.describe_training_job(TrainingJobName=job_name) # ジョブ状態 job_status = job_data['TrainingJobStatus'] # インスタンスタイプ instance_type = job_data['ResourceConfig']['InstanceType'] # ジョブ実行時間 job_total_seconds = (job_data['TrainingEndTime'] - job_data['TrainingStartTime']).total_seconds() job_time_text = create_time_text(job_total_seconds) # インスタンス価格を取得する instance_price = get_instance_price(region, instance_type) # メトリクス(lambdaのbotocoreのバージョンが古いため、未対応。今後に期待) metric_data = create_metric_text(job_data['FinalMetricDataList']) if 'FinalMetricDataList' in job_data else 'データなし' # 入力データ input_data = create_input_data_text(job_data['InputDataConfig']) # 学習コストの概算 job_cost = job_total_seconds / 3600 * float(instance_price) job_info = ''' ジョブ名: {job_name} ステータス: {job_status} 概算費用(USD): ${job_cost:,.3f} 学習時間: {job_time}({job_start_time}〜{job_end_time}) インスタンスタイプ: {instance_type} ({instance_price:,.3f} USD/h) モデルアーティファクト: {model_artifacts_uri} 入力データ: {input_data} メトリクス(未対応): {metric_data} ジョブ詳細: <https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#/jobs/{job_name}> '''.format( region = region, job_status = job_status, job_cost = job_cost, job_name = job_name, instance_type = instance_type, instance_price = float(instance_price), job_start_time = job_data['TrainingStartTime'].astimezone(JST).strftime('%Y/%m/%d %H:%M:%S'), job_end_time = job_data['TrainingEndTime'].astimezone(JST).strftime('%Y/%m/%d %H:%M:%S'), job_time = job_time_text, input_data = input_data, model_artifacts_uri = job_data['ModelArtifacts']['S3ModelArtifacts'], metric_data = metric_data ) return job_info def get_first_element(dict_obj): """辞書の最初の要素を取得する""" return next(iter(dict_obj.values())) def get_instance_price(region, instance_type): """リージョンとインスタンスタイプに応じた価格を取得する""" # 対応していないリージョンがあるため、決め打ちでus-east-1を使う pricing = boto3.client('pricing', region_name='us-east-1') # 対応するリージョンとインスタンスタイプの料金を取得する response = pricing.get_products( ServiceCode='AmazonSageMaker', Filters=[ { 'Type': 'TERM_MATCH', 'Field': 'location', 'Value': REGION_MAP[region] }, { 'Type': 'TERM_MATCH', 'Field': 'instanceType', 'Value': instance_type + '-Training' } ] ) return get_first_element(get_first_element(json.loads(response['PriceList'][0])['terms']['OnDemand'])['priceDimensions'])['pricePerUnit']['USD'] def create_time_text(total_seconds): """秒数から読みやすい時間表記を作る""" days, remainder = divmod(total_seconds, 86400) hours, remainder = divmod(remainder, 3600) minutes, seconds = divmod(remainder, 60) time_text = '' if days > 0: time_text += str(int(days))+'日' if hours > 0: time_text += str(int(hours))+'時間' if minutes > 0: time_text += str(int(minutes))+'分' time_text += str(int(seconds))+'秒' return time_text def create_metric_text(metric_list): """メトリクスデータを作成する""" text = '' for metric_data in metric_list: text += ' {metric_name}: {metric_value}\n'.format( metric_name = metric_data['MetricName'], metric_value = metric_data['Value'] ) return text def create_input_data_text(input_data_config): """入力データ情報を作成する""" if len(input_data_config) == 0: # 強化学習の場合などはデータが無い return 'データなし' text = '' for input_data in input_data_config: text += ' {channel_name}: {s3uri}\n'.format( channel_name = input_data['ChannelName'], s3uri = input_data['DataSource']['S3DataSource']['S3Uri'] ) return text
※ スクリプト中に学習の最後のメトリクスデータを表示できるような処理を含んでいますが、機能しません。Lambdaのランタイム環境のbotocoreのバージョンが低いため、describe_training_jobがFinalMetricDataListに対応していなかったためです。そのうち更新されるだろうという希望のもと、参考程度に残しています。(2019/1/22現在)
環境変数の設定
スクリプトで使用するパラメータを環境変数として設定します。
- BOT_USERNAME: Slackへ投稿時に表示される名前です。
- SLACK_CHANNEL: 投稿先チャンネル名です。
- SLACK_URL: Incoming Webhooks設定時に取得したエンドポイントURLです。このURLにデータをPOSTすることでSlackに投稿することができます。
トリガーの設定
次にLambda関数のトリガーを設定します。
S3の特定のバケットに接尾辞がoutput/model.tar.gz
のオブジェクトが保存されるとLambda関数が実行されるように設定し、イベントとして追加します。イベントの追加は対応させるバケットの数だけ必要になります。
イベントの追加が終われば必要な設定は終わりです。右上のSaveでLambda関数を保存します。
動作確認
これまでに動かした学習ジョブのモデルアーティファクトをダウンロードし、同じ場所へアップロードし直すと学習時を再現できます。 実際にやってみると、こんな感じでSlackへ投稿されます。
情報が不足していたり、多かったり、通知内容が見にくかったりすると思うので、必要に応じてスクリプトを修正してください〜。
さいごに
Amazon SageMakerで学習ジョブが完了した際に学習ジョブ情報をSlackに通知するbotを作ってみました。少々強引な方法になってしまいましたが、それっぽいものは出来ました。これでマネジメントコンソールに入らなくても、どういうジョブが行われたかが分かります。Slackへの通知部分を書き換えることで別のサービスとの連携も可能です。使い方によっては便利なものになるかもしれないですね。
お読みくださり、ありがとうございました〜!